{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "452cad0e",
   "metadata": {},
   "source": [
    "# Progress Bar for Ray Actors (tqdm)\n",
    "\n",
    "Tracking progress of distributed tasks can be tricky.\n",
    "\n",
    "This script will demonstrate how to implement a simple\n",
    "progress bar for a Ray actor to track progress across various\n",
    "different distributed components.\n",
    "\n",
    "Original source: [Link](https://github.com/votingworks/arlo-e2e)\n",
    "\n",
    "## Setup: Dependencies\n",
    "\n",
    "First, import some dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccdf35fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspiration: https://github.com/honnibal/spacy-ray/pull/\n",
    "# 1/files#diff-7ede881ddc3e8456b320afb958362b2aR12-R45\n",
    "from asyncio import Event\n",
    "from typing import Tuple\n",
    "from time import sleep\n",
    "\n",
    "import ray\n",
    "\n",
    "# For typing purposes\n",
    "from ray.actor import ActorHandle\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bfe2866",
   "metadata": {},
   "source": [
    "This is the Ray \"actor\" that can be called from anywhere to update\n",
    "our progress. You'll be using the `update` method. Don't\n",
    "instantiate this class yourself. Instead,\n",
    "it's something that you'll get from a `ProgressBar`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97f4714",
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "class ProgressBarActor:\n",
    "    counter: int\n",
    "    delta: int\n",
    "    event: Event\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        self.counter = 0\n",
    "        self.delta = 0\n",
    "        self.event = Event()\n",
    "\n",
    "    def update(self, num_items_completed: int) -> None:\n",
    "        \"\"\"Updates the ProgressBar with the incremental\n",
    "        number of items that were just completed.\n",
    "        \"\"\"\n",
    "        self.counter += num_items_completed\n",
    "        self.delta += num_items_completed\n",
    "        self.event.set()\n",
    "\n",
    "    async def wait_for_update(self) -> Tuple[int, int]:\n",
    "        \"\"\"Blocking call.\n",
    "\n",
    "        Waits until somebody calls `update`, then returns a tuple of\n",
    "        the number of updates since the last call to\n",
    "        `wait_for_update`, and the total number of completed items.\n",
    "        \"\"\"\n",
    "        await self.event.wait()\n",
    "        self.event.clear()\n",
    "        saved_delta = self.delta\n",
    "        self.delta = 0\n",
    "        return saved_delta, self.counter\n",
    "\n",
    "    def get_counter(self) -> int:\n",
    "        \"\"\"\n",
    "        Returns the total number of complete items.\n",
    "        \"\"\"\n",
    "        return self.counter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29bcd48f",
   "metadata": {},
   "source": [
    "This is where the progress bar starts. You create one of these\n",
    "on the head node, passing in the expected total number of items,\n",
    "and an optional string description.\n",
    "Pass along the `actor` reference to any remote task,\n",
    "and if they complete ten\n",
    "tasks, they'll call `actor.update.remote(10)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5737bc61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back on the local node, once you launch your remote Ray tasks, call\n",
    "# `print_until_done`, which will feed everything back into a `tqdm` counter.\n",
    "\n",
    "\n",
    "class ProgressBar:\n",
    "    progress_actor: ActorHandle\n",
    "    total: int\n",
    "    description: str\n",
    "    pbar: tqdm\n",
    "\n",
    "    def __init__(self, total: int, description: str = \"\"):\n",
    "        # Ray actors don't seem to play nice with mypy, generating\n",
    "        # a spurious warning for the following line,\n",
    "        # which we need to suppress. The code is fine.\n",
    "        self.progress_actor = ProgressBarActor.remote()  # type: ignore\n",
    "        self.total = total\n",
    "        self.description = description\n",
    "\n",
    "    @property\n",
    "    def actor(self) -> ActorHandle:\n",
    "        \"\"\"Returns a reference to the remote `ProgressBarActor`.\n",
    "\n",
    "        When you complete tasks, call `update` on the actor.\n",
    "        \"\"\"\n",
    "        return self.progress_actor\n",
    "\n",
    "    def print_until_done(self) -> None:\n",
    "        \"\"\"Blocking call.\n",
    "\n",
    "        Do this after starting a series of remote Ray tasks, to which you've\n",
    "        passed the actor handle. Each of them calls `update` on the actor.\n",
    "        When the progress meter reaches 100%, this method returns.\n",
    "        \"\"\"\n",
    "        pbar = tqdm(desc=self.description, total=self.total)\n",
    "        while True:\n",
    "            delta, counter = ray.get(self.actor.wait_for_update.remote())\n",
    "            pbar.update(delta)\n",
    "            if counter >= self.total:\n",
    "                pbar.close()\n",
    "                return"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa3a8a67",
   "metadata": {},
   "source": [
    "This is an example of a task that increments the progress bar.\n",
    "Note that this is a Ray Task, but it could very well\n",
    "be any generic Ray Actor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2afb6951",
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "def sleep_then_increment(i: int, pba: ActorHandle) -> int:\n",
    "    sleep(i / 2.0)\n",
    "    pba.update.remote(1)\n",
    "    return i"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "563238a0",
   "metadata": {},
   "source": [
    "Now you can run it and see what happens!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed87603d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run():\n",
    "    ray.init()\n",
    "    num_ticks = 6\n",
    "    pb = ProgressBar(num_ticks)\n",
    "    actor = pb.actor\n",
    "    # You can replace this with any arbitrary Ray task/actor.\n",
    "    tasks_pre_launch = [\n",
    "        sleep_then_increment.remote(i, actor) for i in range(0, num_ticks)\n",
    "    ]\n",
    "\n",
    "    pb.print_until_done()\n",
    "    tasks = ray.get(tasks_pre_launch)\n",
    "\n",
    "    tasks == list(range(num_ticks))\n",
    "    num_ticks == ray.get(actor.get_counter.remote())\n",
    "\n",
    "\n",
    "run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}